{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/1_train_bc.ipynb)\n",
    "# Train an Agent using Behavior Cloning\n",
    "\n",
    "Behavior cloning is the most naive approach to imitation learning. \n",
    "We take the transitions of trajectories taken by some expert and use them as training samples to train a new policy.\n",
    "The method has many drawbacks and often does not work. \n",
    "However in this example, where we use an agent for the seals/CartPole-v0 environment, it is feasible.\n",
    "\n",
    "Note that we use a variant of the CartPole environment from the seals package, which has fixed episode durations. Read more about why we do this [here](https://imitation.readthedocs.io/en/latest/main-concepts/variable_horizon.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we need some kind of expert in CartPole so we can sample some expert trajectories.\n",
    "For convenience we just download one from the HuggingFace model hub.\n",
    "\n",
    "If you want to train an expert yourself have a look at the [training documenation](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/train.html#basic-usage) of RL Baselines3 Zoo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import gymnasium as gym\n",
    "from imitation.policies.serialize import load_policy\n",
    "from imitation.util.util import make_vec_env\n",
    "from imitation.data.wrappers import RolloutInfoWrapper\n",
    "\n",
    "env = make_vec_env(\n",
    "    \"seals:seals/CartPole-v0\",\n",
    "    rng=np.random.default_rng(),\n",
    "    post_wrappers=[\n",
    "        lambda env, _: RolloutInfoWrapper(env)\n",
    "    ],  # needed for computing rollouts later\n",
    ")\n",
    "expert = load_policy(\n",
    "    \"ppo-huggingface\",\n",
    "    organization=\"HumanCompatibleAI\",\n",
    "    env_name=\"seals/CartPole-v0\",\n",
    "    venv=env,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's quickly check if the expert is any good.\n",
    "We usually should be able to reach a reward of 500, which is the maximum achievable value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "reward, _ = evaluate_policy(expert, env, 10)\n",
    "print(reward)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can use the expert to sample some trajectories.\n",
    "We flatten them right away since we are only interested in the individual transitions for behavior cloning.\n",
    "`imitation` comes with a number of helper functions that makes collecting those transitions really easy. First we collect 50 episode rollouts, then we flatten them to just the transitions that we need for training.\n",
    "\n",
    "Note that the rollout function requires a vectorized environment and needs the `RolloutInfoWrapper` around each of the environments. This is why we passed the `post_wrappers` argument to `make_vec_env` above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.data import rollout\n",
    "\n",
    "rng = np.random.default_rng()\n",
    "rollouts = rollout.rollout(\n",
    "    expert,\n",
    "    env,\n",
    "    rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n",
    "    rng=rng,\n",
    ")\n",
    "transitions = rollout.flatten_trajectories(rollouts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's have a quick look at what we just generated using those library functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\n",
    "    f\"\"\"The `rollout` function generated a list of {len(rollouts)} {type(rollouts[0])}.\n",
    "After flattening, this list is turned into a {type(transitions)} object containing {len(transitions)} transitions.\n",
    "The transitions object contains arrays for: {', '.join(transitions.__dict__.keys())}.\"\n",
    "\"\"\"\n",
    ")"
   ]
  },
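  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the flattening step concrete, the following toy sketch shows the idea with plain NumPy.\n",
    "It is not `imitation`'s implementation: `toy_trajectories` and `toy_flatten` are hypothetical names made up for this illustration, and each toy trajectory is assumed to hold one more observation than actions, so that every action has both a preceding and a following observation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy data: each trajectory has T actions and T + 1 observations.\n",
    "toy_trajectories = [\n",
    "    {\"obs\": np.arange(4).reshape(4, 1), \"acts\": np.array([0, 1, 0])},\n",
    "    {\"obs\": np.arange(10, 13).reshape(3, 1), \"acts\": np.array([1, 1])},\n",
    "]\n",
    "\n",
    "\n",
    "def toy_flatten(trajectories):\n",
    "    # Split every episode into per-step pieces and concatenate across episodes.\n",
    "    obs, acts, next_obs, dones = [], [], [], []\n",
    "    for traj in trajectories:\n",
    "        n_steps = len(traj[\"acts\"])\n",
    "        obs.append(traj[\"obs\"][:-1])  # observation before each action\n",
    "        acts.append(traj[\"acts\"])\n",
    "        next_obs.append(traj[\"obs\"][1:])  # observation after each action\n",
    "        step_dones = np.zeros(n_steps, dtype=bool)\n",
    "        step_dones[-1] = True  # only the last step ends the episode\n",
    "        dones.append(step_dones)\n",
    "    return {\n",
    "        \"obs\": np.concatenate(obs),\n",
    "        \"acts\": np.concatenate(acts),\n",
    "        \"next_obs\": np.concatenate(next_obs),\n",
    "        \"dones\": np.concatenate(dones),\n",
    "    }\n",
    "\n",
    "\n",
    "toy_transitions = toy_flatten(toy_trajectories)\n",
    "print({key: value.shape for key, value in toy_transitions.items()})"
   ]
  },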
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After we collected our transitions, it's time to set up our behavior cloning algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.algorithms import bc\n",
    "\n",
    "bc_trainer = bc.BC(\n",
    "    observation_space=env.observation_space,\n",
    "    action_space=env.action_space,\n",
    "    demonstrations=transitions,\n",
    "    rng=rng,\n",
    ")"
   ]
  },
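  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, behavior cloning is supervised learning: the policy is fitted so that it assigns high probability to the expert's action in each observed state.\n",
    "The sketch below illustrates this with a linear softmax policy trained by gradient descent on synthetic data, using only NumPy.\n",
    "Everything in it (the `toy_` names, the synthetic expert rule, the learning rate) is an assumption made for illustration; `bc.BC` trains its own policy network and its actual loss differs in details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "toy_rng = np.random.default_rng(0)\n",
    "\n",
    "# Hypothetical expert data: the expert picks action 1 when the first feature is positive.\n",
    "toy_obs = toy_rng.normal(size=(256, 4))\n",
    "toy_acts = (toy_obs[:, 0] > 0).astype(int)\n",
    "toy_one_hot = np.eye(2)[toy_acts]\n",
    "\n",
    "# Linear softmax policy over two discrete actions.\n",
    "toy_weights = np.zeros((4, 2))\n",
    "\n",
    "\n",
    "def action_probs(obs, weights):\n",
    "    logits = obs @ weights\n",
    "    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability\n",
    "    exp_logits = np.exp(logits)\n",
    "    return exp_logits / exp_logits.sum(axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "def neg_log_likelihood(obs, acts, weights):\n",
    "    # Average negative log-probability of the expert's actions under the policy.\n",
    "    probs = action_probs(obs, weights)\n",
    "    return -np.mean(np.log(probs[np.arange(len(acts)), acts]))\n",
    "\n",
    "\n",
    "loss_before = neg_log_likelihood(toy_obs, toy_acts, toy_weights)\n",
    "for _ in range(200):\n",
    "    probs = action_probs(toy_obs, toy_weights)\n",
    "    grad = toy_obs.T @ (probs - toy_one_hot) / len(toy_acts)\n",
    "    toy_weights -= 0.5 * grad  # plain gradient descent step\n",
    "loss_after = neg_log_likelihood(toy_obs, toy_acts, toy_weights)\n",
    "\n",
    "accuracy = np.mean(action_probs(toy_obs, toy_weights).argmax(axis=1) == toy_acts)\n",
    "print(f\"loss before: {loss_before:.3f}, loss after: {loss_after:.3f}\")\n",
    "print(f\"fraction of expert actions reproduced: {accuracy:.2f}\")"
   ]
  },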
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see the untrained policy only gets poor rewards:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reward_before_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n",
    "print(f\"Reward before training: {reward_before_training}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, we can match the rewards of the expert (500):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bc_trainer.train(n_epochs=1)\n",
    "reward_after_training, _ = evaluate_policy(bc_trainer.policy, env, 10)\n",
    "print(f\"Reward after training: {reward_after_training}\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
